#include <torch/csrc/jit/ir.h>
#include <torch/csrc/jit/testing/file_check.h>
#include <torch/jit.h>
#include "test/cpp/jit/test_base.h"
#include "torch/csrc/jit/custom_operator.h"

#include <sstream>
#include <string>

namespace torch {
namespace jit {

void testSchemaMatching() {
  {
    RegisterOperators reg({
        Operator(
            "aten::test_vartype(t[] a, t b) -> (t)",
            [](const Node* node) -> Operation {
              return [](Stack& stack) {
                c10::List<double> list;
                double a;
                pop(stack, list, a);
                push(stack, a);
                return 0;
              };
            }),
    });
    script::Module m("m");
    m.define(R"(
      def test(self):
        a = (1.0, 2.0)
        return torch.test_vartype(a, 2.0)
    )");
    auto result = m.run_method("test");
    TORCH_INTERNAL_ASSERT(result.toDouble() == 2.0);

    const std::string error_example = R"JIT(
      def test_2(self):
          a = (1.0, 2.0)
          non_float = (1, 1)
          return torch.test_vartype(a, non_float)
    )JIT";

    std::string err = "";
    try {
      m.define(error_example);
    } catch (const std::exception &e) {
      err = e.what();
    }
    TORCH_INTERNAL_ASSERT(err.find("previously matched to type") != std::string::npos);
  }
  {
    RegisterOperators reg({
        Operator(
            "aten::test_vartype2(t a, t[] b) -> (t[])",
            [](const Node* node) -> Operation {
              return [](Stack& stack) {
                double a;
                c10::List<double> list;
                pop(stack, a, list);
                push(stack, a);
                return 0;
              };
            }),
    });
    script::Module m("m");
    m.define(R"JIT(
      def test(self):
          a = (1.0, 2.0)
          return torch.test_vartype2(3.0, a)
    )JIT");
    auto result = m.run_method("test");
    TORCH_INTERNAL_ASSERT(result.toDouble() == 3.0);

    static const auto error_exam2 = R"JIT(
      def test_2(self):
          a = (1, 2)
          return torch.test_vartype2(3.0, a)
    )JIT";


    std::string err = "";
    try {
      m.define(error_exam2);
    } catch (const std::exception &e) {
      err = e.what();
    }
    TORCH_INTERNAL_ASSERT(err.find("previously matched to type") != std::string::npos);
  }
}
} // namespace jit
} // namespace torch
